
Conversation

@lhutton1
Contributor

@lhutton1 lhutton1 commented Oct 15, 2025

This commit adds support for the OCP-MX INT8 type. This includes the following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED, CAST_TO_BLOCK_SCALED and CONST.

The support is added via a custom TOSA type "!tosa.mxint8", since mxint8 is not yet a builtin type in MLIR. This may change in the future, depending on how the type is used by other frameworks/dialects. For the same reason, conversions to/from this type have not yet been implemented.
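
A minimal C++ sketch of how the custom type can be queried from client code once this lands (the class name tosa::mxint8Type and the helper tosa::getBitWidth come from the diff below; the wrapper functions themselves are hypothetical):

#include "mlir/Dialect/Tosa/IR/TosaOps.h"

// Returns true if the tensor carries the custom !tosa.mxint8 element type.
static bool isMxInt8Tensor(mlir::RankedTensorType tensorTy) {
  return mlir::isa<mlir::tosa::mxint8Type>(tensorTy.getElementType());
}

// tosa::getBitWidth handles !tosa.mxint8 (reporting 8 bits), whereas
// Type::getIntOrFloatBitWidth only accepts builtin integer/float types.
static unsigned elementBits(mlir::RankedTensorType tensorTy) {
  return mlir::tosa::getBitWidth(tensorTy.getElementType());
}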

Note: This PR relies on #156425, #163433, #163436 and #163641 so also contains their contents.

Co-authored-by: Tat Wai Chong tatwai.chong@arm.com

@github-actions

github-actions bot commented Oct 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@lhutton1 lhutton1 marked this pull request as ready for review October 24, 2025 15:13
@lhutton1
Contributor Author

This is rebased & ready for review.

@llvmbot
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit adds support for the OCP-MX INT8 type. This includes the following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED, CAST_TO_BLOCK_SCALED and CONST.

The support is added via a custom TOSA type "!tosa.mxint8", since mxint8 is not yet a builtin type in MLIR. This may change in the future, depending on how the type is used by other frameworks/dialects. For the same reason, conversions to/from this type have not yet been implemented.

Note: This PR relies on #156425, #163433, #163436 and #163641 so also contains their contents.

Co-authored-by: Tat Wai Chong <tatwai.chong@arm.com>


Full diff: https://github.com/llvm/llvm-project/pull/163642.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+12-5)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+21-12)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+6)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4-3)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+21)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+56)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 8b5934ff0630e..c774d870a8c45 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -572,6 +572,8 @@ extensionComplianceMap = {
         {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T},
          SpecificationVersion::V_1_1_DRAFT},
         {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, mxint8T, fp8ue8m0T, fp32T},
          SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.max_pool2d",
      {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
@@ -870,14 +872,16 @@ extensionComplianceMap = {
         {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+        {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf},
       {{Extension::mxfp},
        {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
+        {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.cast_to_block_scaled",
      {{{Extension::mxfp},
        {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
@@ -885,12 +889,14 @@ extensionComplianceMap = {
         {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
+        {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::bf16, Extension::mxfp},
        {{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
-        {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
+        {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{bf16T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf}}},
     {"tosa.rescale",
      {{{Extension::int16},
@@ -908,7 +914,8 @@ extensionComplianceMap = {
        {{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}},
+        {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.identity",
      {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
       {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a15f073bc5fcb..2d4e7cf8b9dbd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -179,6 +179,9 @@ Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
 // returns type of variable op
 RankedTensorType getVariableType(VariableOp variableOp);
 
+// Returns the bitwidth of a TOSA tensor element type
+unsigned getBitWidth(Type type);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 45d380c1b2e6c..ea58f49b64c44 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -70,7 +70,7 @@ class ProfileInfoDepot {
 
 private:
   TypeInfo convertTypeToInfo(Type type) {
-    return {type.getTypeID(), type.getIntOrFloatBitWidth()};
+    return {type.getTypeID(), tosa::getBitWidth(type)};
   }
 
   TypeInfo convertValueToInfo(Value value) {
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 93843e86fd378..414b51bf4b135 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -22,6 +22,12 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
 // Tosa Type Definitions.
 //===----------------------------------------------------------------------===//
 
+// The base class for Tosa dialect types.
+class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Tosa_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
 // The base class of a quantized type.
 // Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
 // Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
@@ -78,13 +84,26 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
+//===----------------------------------------------------------------------===//
+// Custom TOSA element types.
+//===----------------------------------------------------------------------===//
+
+// MLIR doesn't have a builtin type for mxint8 yet. For now, declare it as a
+// custom TOSA type. This may be changed in the future.
+def Tosa_MXInt8 : Tosa_Type<"mxint8", "mxint8"> {
+  let summary = "INT8 type as defined by OCP-MX";
+  let description = [{
+    8-bit integer format with an implicit 1/64 scale defined by OCP-MX.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
-def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
+def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat, Tosa_MXInt8],
                                 "number">;
 
-def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
+def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN, Tosa_MXInt8],
                                 "micro-scaling format number">;
 def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;
 
@@ -265,16 +284,6 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
 def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
 def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
 
-//===----------------------------------------------------------------------===//
-// Tosa Type Definitions.
-//===----------------------------------------------------------------------===//
-
-// The base class for Tosa dialect types.
-class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
-    : TypeDef<Tosa_Dialect, name, traits> {
-  let mnemonic = typeMnemonic;
-}
-
 //===----------------------------------------------------------------------===//
 // ShapeType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0aff67f0b5eba..bf3810ff231da 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -606,6 +606,12 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
   return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
 }
 
+unsigned mlir::tosa::getBitWidth(Type type) {
+  if (dyn_cast<tosa::mxint8Type>(type))
+    return 8;
+  return type.getIntOrFloatBitWidth();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ab363ee6b4d2a..ddd9c70402fdc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -31,6 +31,7 @@ TosaProfileCompliance::TosaProfileCompliance() {
   const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
   const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
   const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+  const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
 
 // The profile-based compliance content below is auto-generated by a script
 // in https://git.mlplatform.org/tosa/specification.git
@@ -625,6 +626,8 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
     return {"fp4e2m1"};
   } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
     return {"fp8e8m0"};
+  } else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) {
+    return {"mxint8"};
   }
   llvm_unreachable("unknown type");
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4d0b61acc4ea4..9676ea5ca4868 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -693,7 +693,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
                                  << " shape dimension cannot be dynamic";
     }
 
-    int64_t element_bits = type.getElementTypeBitWidth();
+    int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
     int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
     int64_t size = element_bytes * type.getNumElements();
 
@@ -1217,9 +1217,10 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
         return true;
       }
     }
-  } else if (mlir::isa<tosa::shapeType>(type)) {
+  } else if (mlir::isa<tosa::shapeType>(type))
+    return true;
+  else if (isa<tosa::mxint8Type>(type))
     return true;
-  }
   return false;
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 865f712ce1a5a..22fde3b7d28a5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1269,6 +1269,13 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
   return %0 : tensor<4x8x16xf32>
 }
 
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
+func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
 // -----
 // CHECK-LABEL: test_cast_from_block_scaled_static
 func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
@@ -1296,3 +1303,17 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
 }
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_mxint8
+func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: test_const_mxint8
+func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
+    %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
+    return %0 : tensor<2x!tosa.mxint8>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index f3d8dab2f6b0f..d8cbaa2c356c3 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -82,3 +82,59 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_mxint8
+func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_fp6e3m2
+func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
+    %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
+    return %0 : tensor<4xf6E3M2FN>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_mxint8
+func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
+    %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
+    return %0 : tensor<2x!tosa.mxint8>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_f4e2m1
+func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
+  return %0 : tensor<13x21x3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
+func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_mxint8
+func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_mxint8
+func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
+    %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
+    return %0 : tensor<2x!tosa.mxint8>
+}
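
For context, a hedged numeric sketch of what the mxint8 encoding above represents (assuming the standard OCP-MX MXINT8 semantics summarised in the type description; this helper is illustrative and not part of the diff):

#include <cstdint>

// An MXINT8 element is an 8-bit two's-complement value with an implicit 1/64
// scale; the shared per-block scale is a power of two carried as fp8 E8M0.
double decodeMxInt8(int8_t element, double blockScale) {
  return (static_cast<double>(element) / 64.0) * blockScale;
}

// Example: the constant bytes "0x007F" in the tests above hold 0 and 127,
// which decode to 0.0 and 127/64 ≈ 1.984 with a block scale of 1.0.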

@psunn
Contributor

psunn commented Oct 28, 2025

Thanks for keeping it updated. I will help review later today.

@psunn psunn self-requested a review October 28, 2025 10:14
lhutton1 and others added 2 commits October 28, 2025 11:50
This commit adds support for the OCP-MX INT8 type. This includes the
following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED,
CAST_TO_BLOCK_SCALED and CONST.

The support is added via a custom TOSA type "!tosa.mxint8", since mxint8
is not yet a builtin type in MLIR. This may change in the future,
depending on how the type is used by other frameworks/dialects. For the
same reason, conversions to/from this type have not yet been
implemented.

Co-authored-by: Tat Wai Chong <tatwai.chong@arm.com>
Change-Id: I6dbba8d55075111cae6b3186cef90fd87d9e5ae6
* remove duplicate test
* remove unused input parameter
* use a list of strings, rather than one string with hex values

Change-Id: I344eb12ed3b6cbd6d5cca04647667296f49e9df4
Contributor

@psunn psunn left a comment

Looks reasonable to me. The added MXINT8 type and related operators align with the TOSA specification, and test cases are included.

Nit: please squash the current two commits into one before merging.

Change-Id: Iaa685661cc34d5a738260821ecfaff12f63812b4
@lhutton1
Contributor Author

Thanks for the review. The commits are automatically squashed on merge :)

@lhutton1 lhutton1 merged commit 334d438 into llvm:main Oct 29, 2025
10 checks passed
@lhutton1 lhutton1 deleted the mxfp-mxint8 branch October 29, 2025 10:59
@llvm-ci
Collaborator

llvm-ci commented Oct 29, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia running on mlir-nvidia while building mlir at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/21053

Here is the relevant piece of the build log, for reference:
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/GPU/CUDA/async.mlir' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-kernel-outlining  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary="format=fatbin"  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-runner    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_cuda_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_async_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_runner_utils.so    --entry-point-result=void -O0  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-kernel-outlining
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt '-pass-pipeline=builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary=format=fatbin
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-runner --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_cuda_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_async_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_runner_utils.so --entry-point-result=void -O0
# .---command stderr------------
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventSynchronize(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# `-----------------------------
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# .---command stderr------------
# | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir:68:12: error: CHECK: expected string not found in input
# |  // CHECK: [84, 84]
# |            ^
# | <stdin>:1:1: note: scanning from here
# | Unranked Memref base@ = 0x5ada408feac0 rank = 1 offset = 0 sizes = [2] strides = [1] data = 
# | ^
# | <stdin>:2:1: note: possible intended match here
# | [42, 42]
# | ^
# | 
# | Input file: <stdin>
# | Check file: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             1: Unranked Memref base@ = 0x5ada408feac0 rank = 1 offset = 0 sizes = [2] strides = [1] data =  
# | check:68'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
# |             2: [42, 42] 
# | check:68'0     ~~~~~~~~~
# | check:68'1     ?         possible intended match
...
